
from isq import LocalDevice, optv
import jax
import time
import jax.numpy as jnp

#定义量子线路
isq_str = '''
    qbit q[2];
    RX(theta[0], q[0]);
    RY(theta[1], q[1]);
    M(q[0]);
'''

#定义device
ld = LocalDevice()

#定义计算函数
def calc(params):
    #使用probs函数进行带参编译和模拟，其中需要优化的参数通过optv封装后传入，mod值为1，使用jax.numpy计算
    c = ld.probs(isq_str, mod = 1, theta = optv(params))
    return c[0] - c[1]

theta = jnp.array([[0.2, 0.4] for _ in range(10)])

t = time.time()
g = jax.grad(calc)
c = jax.vmap(g)
d = c(theta)
e = time.time()
print(e-t)
print(d)